Here, I will examine the utility of sampling-based and variational inference for recovering parameter values of a simple network diffusion model on Erdős–Rényi random graphs. The primary aim of this document is to assess how well inference scales as the network grows in size and as its topology -- in this case, connection probability -- changes.
First, we check the environment to ensure all required packages are present and document their versions.
using Pkg
Pkg.status()
Status `~/Projects/NetworkTopology/Project.toml`
  [76274a88] Bijectors v0.8.14
  [a93c6f00] DataFrames v0.22.5
  [0c46a032] DifferentialEquations v6.16.0
  [31c24e10] Distributions v0.24.13
  [7073ff75] IJulia v1.23.1
  [093fc24a] LightGraphs v1.3.5
  [c7f686f2] MCMCChains v4.7.0
  [91a5bcdd] Plots v1.10.4
  [37e2e3b7] ReverseDiff v1.5.0
  [f3b207a7] StatsPlots v0.14.19
  [fce5fe82] Turing v0.15.10
  [e88e6eb3] Zygote v0.6.3
using Random
using DifferentialEquations
using Turing
using Plots
using StatsPlots
using MCMCChains
using LightGraphs
Random.seed!(1)
The first step in defining our model is to initialise a graph on which to run it. We use LightGraphs to generate an Erdős–Rényi random graph of size N with connection probability P, returning its Laplacian and adjacency matrices.
function make_graph(N::Int64, P::Float64)
    G = erdos_renyi(N, P)
    L = laplacian_matrix(G)
    A = adjacency_matrix(G)
    return L, A
end
N = 5
P = 0.5
L, A = make_graph(N, P);
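As a quick sanity check on what `make_graph` returns: the combinatorial Laplacian satisfies L = D - A, where D is the diagonal degree matrix. A minimal Base-only sketch on a hand-built 3-node path graph (the matrices below are illustrative, not output from `make_graph`; variables are suffixed `_path` to avoid clobbering `L` and `A` above):

```julia
# Hand-built adjacency matrix for the 3-node path graph 1 - 2 - 3
A_path = [0 1 0;
          1 0 1;
          0 1 0]

# Degree matrix: row sums of A on the diagonal
D_path = [i == j ? sum(A_path[i, :]) : 0 for i in 1:3, j in 1:3]

# Combinatorial Laplacian: L = D - A
L_path = D_path - A_path

# Every row of L sums to zero, so constant vectors lie in its null space
@assert all(sum(L_path, dims=2) .== 0)
```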
The second step of the modelling process is to define the ODE model. For network diffusion, this is given by:

$$ \frac{d\mathbf{u}}{dt} = -\rho \mathbf{L} \mathbf{u} $$

We can set this up as a Julia function as follows:
function NetworkDiffusion(u, p, t)
    du = -p * L * u  # L is the graph Laplacian defined above (a global)
end
NetworkDiffusion (generic function with 1 method)
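One useful property of this model: because each row (and, by symmetry, each column) of the Laplacian sums to zero, network diffusion conserves the total mass sum(u). A small Base-only sketch on a hand-built 3-node Laplacian (the matrix, rate, and step size here are illustrative, not the graph generated above):

```julia
# Hand-built Laplacian of the 3-node path graph; rows sum to zero
L_demo = [ 1 -1  0;
          -1  2 -1;
           0 -1  1]
ρ_demo = 2.0
u_demo = [1.0, 0.0, 0.0]

# One explicit Euler step of du/dt = -ρ L u
dt = 0.01
u_next = u_demo .- dt * ρ_demo * (L_demo * u_demo)

# Total mass is unchanged (up to floating point)
@assert isapprox(sum(u_next), sum(u_demo))
```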
To run a simulation, we set some initial conditions and define an ODEProblem using DifferentialEquations.
u0 = rand(N)
p = 2.0
t_span = (0.0,1.0);
problem = ODEProblem(NetworkDiffusion, u0, (0.0,1.0), p);
sol = solve(problem, AutoTsit5(Rosenbrock23()), saveat=0.05);
And we can view the solution.
plotly()
plot(sol)
Now that we have a model, we can generate some data and start using Turing to perform inference.
To do this, we define a generative model.
Our data $\mathbf{y}$ is given by a normal distribution centred on our model $f(\mathbf{u}_0, \rho)$ with standard deviation $\sigma$:

$$\mathbf{y} \sim \mathcal{N}(f(\mathbf{u}_0, \rho), \sigma)$$

and we assume our parameters are drawn from the following priors (square brackets denote truncation bounds):

$$\sigma \sim \Gamma^{-1}(2, 3)$$
$$\rho \sim \mathcal{N}(5, 10, [0, 10])$$
$$\mathbf{u}_0 \sim \mathcal{N}(0.5, 2, [0, 1])$$
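For intuition about the truncated priors, here is a Base-only rejection-sampling sketch of a truncated normal. This is illustrative only; Turing and Distributions provide `truncated` and sample it more efficiently.

```julia
# Rejection sampler for Normal(μ, σ) truncated to [lo, hi].
# Illustrative sketch only: draw from the untruncated normal and
# return the first draw that lands inside the bounds.
function sample_truncated(μ, σ, lo, hi)
    while true
        x = μ + σ * randn()
        lo <= x <= hi && return x
    end
end

# Draws from the ρ prior above always land in [0, 10]
draws = [sample_truncated(5.0, 10.0, 0.0, 10.0) for _ in 1:1000]
@assert all(0.0 .<= draws .<= 10.0)
```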
We can make this into a Turing model.
using Base.Threads
Turing.setadbackend(:forwarddiff)
:forwarddiff
@model function fitode(data, problem)
    u_n = size(data)[1]
    σ ~ InverseGamma(2, 3)  # ~ is the tilde character
    ρ ~ truncated(Normal(5, 10.0), 0.0, 10)
    u ~ filldist(truncated(Normal(0.5, 2.0), 0.0, 1.0), u_n)
    prob = remake(problem, u0=u, p=ρ)
    predicted = solve(prob, Tsit5(), saveat=0.05)
    @threads for i = 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ)
    end
end
fitode (generic function with 1 method)
To fit this model, we first need to generate some data. We can then feed our data and our model into Turing and begin to sample.
For now, we'll just use the data generated from our ODE solution above.
data = Array(sol)
5×21 Array{Float64,2}:
0.192454 0.192454 0.192454 0.192454 … 0.192454 0.192454 0.192454
0.0571363 0.10259 0.138278 0.166392 0.271119 0.27179 0.272306
0.123272 0.143615 0.161797 0.177788 0.262349 0.263475 0.264431
0.503282 0.463112 0.430598 0.404193 0.287519 0.285993 0.284667
0.403083 0.377456 0.356101 0.338401 0.265787 0.265516 0.265369
We can now perform inference, first by initialising our fit function with the synthetic data and our ODE problem. We can then initialise multiple chains to sample in parallel -- here we use 10 chains.
Once the sampling has completed, we can plot the chains to visualise convergence and posterior distributions of parameters.
model = fitode(data, problem)
chain = sample(model, NUTS(.65), MCMCThreads(), 1_000, 10, progress=false);
┌ Info: Found initial step size
│   ϵ = 0.2
└ @ Turing.Inference /home/chaggar/.julia/packages/Turing/XLLTf/src/inference/hmc.jl:188
(the same message is repeated for each of the 10 chains, with initial step sizes between 0.05 and 0.4)
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC /home/chaggar/.julia/packages/AdvancedHMC/MIxdK/src/hamiltonian.jl:47
summarize(chain)
plot(chain)
Summary Statistics
  parameters     mean     std  naive_se    mcse        ess    rhat
      Symbol  Float64  Float64  Float64  Float64    Float64  Float64

        u[1]   0.1925   0.0065   0.0001   0.0001  8935.1541  1.0000
        u[2]   0.0560   0.0186   0.0002   0.0002  5295.5100  1.0000
        u[3]   0.1220   0.0181   0.0002   0.0002  6050.6603  1.0001
        u[4]   0.5048   0.0156   0.0002   0.0002  5183.2113  1.0001
        u[5]   0.4041   0.0141   0.0001   0.0002  6152.7354  1.0007
           ρ   2.0487   0.1969   0.0020   0.0023  5428.5926  1.0003
           σ   0.0300   0.0030   0.0000   0.0000  7912.8041  0.9997
scatter(vcat(data[:,1],p,0))
scatter!(mean(chain).nt.mean)
We can see from the plots and the chain summary that the chains converge and produce consistent estimates of the posterior distributions. Importantly, the posterior estimates closely correspond to the true model parameters.
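Concretely, we can compare the posterior means against the ground truth (the first column of `data` for the initial condition, and ρ = 2.0). The numbers below are copied from the data matrix and chain summary above:

```julia
# Ground truth: initial condition (first column of the data) and ρ = 2.0
truth = [0.192454, 0.0571363, 0.123272, 0.503282, 0.403083, 2.0]

# Posterior means copied from the chain summary above
posterior_means = [0.1925, 0.0560, 0.1220, 0.5048, 0.4041, 2.0487]

abs_err = abs.(posterior_means .- truth)
maximum(abs_err)  # largest error is on ρ, at about 0.05
```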
With the ODE model and Turing model setup, we can begin to experiment with how inference is affected by changes to network topology and size.
In this first experiment, we will test how well inference scales when we increase the size of the network used to simulate network diffusion.
We can do this by initialising a new network of size N and plugging it into our ODEProblem and Turing model.
In the interest of speed, we will use only one chain for the following examples. In practice, using multiple chains does not take much longer and can be implemented easily, as above.
N = 10
P = 0.5
L, A = make_graph(N, P);
problem = ODEProblem(NetworkDiffusion, rand(10), (0.0,1.0), p);
data = Array(solve(problem, AutoTsit5(Rosenbrock23()), saveat=0.05))
plot(data')
model = fitode(data, problem)
#chain = sample(model, NUTS(.65), MCMCThreads(), 1_000, 10, progress=false)
chain = sample(model, NUTS(0.65), 1_000);
┌ Info: Found initial step size
│   ϵ = 0.2
└ @ Turing.Inference /home/chaggar/.julia/packages/Turing/XLLTf/src/inference/hmc.jl:188
┌ Warning: The current proposal will be rejected due to numerical error(s).
│   isfinite.((θ, r, ℓπ, ℓκ)) = (true, false, false, false)
└ @ AdvancedHMC /home/chaggar/.julia/packages/AdvancedHMC/MIxdK/src/hamiltonian.jl:47
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:32
summarize(chain)
Summary Statistics
  parameters     mean     std  naive_se    mcse        ess    rhat
      Symbol  Float64  Float64  Float64  Float64    Float64  Float64

        u[1]   0.6185   0.0118   0.0004   0.0002  1052.7136  0.9990
        u[2]   0.6781   0.0100   0.0003   0.0003   981.1124  1.0005
        u[3]   0.3143   0.0119   0.0004   0.0002  1203.0726  0.9991
        u[4]   0.5936   0.0123   0.0004   0.0005   839.8567  0.9990
        u[5]   0.7626   0.0088   0.0003   0.0002   850.3368  1.0002
        u[6]   0.9060   0.0114   0.0004   0.0003  1058.6303  0.9995
        u[7]   0.2400   0.0124   0.0004   0.0004  1094.3091  0.9996
        u[8]   0.5659   0.0113   0.0004   0.0004  1040.8340  0.9991
        u[9]   0.6441   0.0110   0.0003   0.0003   847.4087  0.9990
       u[10]   0.3662   0.0120   0.0004   0.0003   926.4561  0.9992
           ρ   2.0066   0.0676   0.0021   0.0020   568.2702  0.9999
           σ   0.0150   0.0011   0.0000   0.0000  1182.1550  0.9993
plot(chain, seriestype = (:traceplot, :histogram))
scatter(vcat(data[:,1],p,0))
scatter!(mean(chain).nt.mean)
The results are stable, with tight distributions around the mean, despite taking relatively few samples from the posterior (1,000).
N = 25
P = 0.5
L, A = make_graph(N, P);
u0 = rand(N);
problem = ODEProblem(NetworkDiffusion, u0, (0.0,1.0), p);
data = Array(solve(problem, Tsit5(), saveat=0.05))
plot(data')
model = fitode(data, problem)
chain = sample(model, NUTS(0.65), 1_000)
#chain = sample(model, NUTS(0.65), MCMCThreads(), 1000, 10, progres=false);
┌ Info: Found initial step size
│ ϵ = 0.2
└ @ Turing.Inference /home/chaggar/.julia/packages/Turing/XLLTf/src/inference/hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:02:21
plot(chain, seriestype = (:traceplot, :histogram))
scatter(vcat(data[:,1],p,0))
scatter!(mean(chain).nt.mean)
The results are still stable, with tight distributions. It is promising that inference remains fast despite the increased number of nodes. This is likely due to the simplicity of the model, which keeps automatic differentiation easy and fast.
It is also encouraging that the results are accurate even though there are fewer informative data points available: the time the diffusion takes to reach equilibrium decreases as we go from 5 to 10 to 25 nodes. This is almost certainly due to the increased number of connections between nodes (higher mean degree). In the next case, there should be even fewer informative data points available.
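The mean-degree claim can be checked directly: the expected mean degree of an Erdős–Rényi graph G(N, P) is P(N - 1), so at P = 0.5 it grows linearly with N. Heuristically, more connections raise the algebraic connectivity of the Laplacian, so diffusion relaxes toward equilibrium faster. A minimal sketch (the helper `expected_mean_degree` is ours, not part of LightGraphs):

```julia
# Expected mean degree of an Erdős–Rényi random graph G(N, P):
# each node has N - 1 potential neighbours, each present with probability P
expected_mean_degree(N, P) = P * (N - 1)

for N in (5, 10, 25, 50)
    println("N = $N: expected mean degree = $(expected_mean_degree(N, 0.5))")
end
```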
N = 50
P = 0.5
L, A = make_graph(N, P);
problem = ODEProblem(NetworkDiffusion, rand(N), (0.0,1.0), p);
data = Array(solve(problem, Tsit5(), saveat=0.05))
plot(data')
model = fitode(data, problem)
chain = sample(model, NUTS(0.65), 1_000)
#chain = sample(model, NUTS(0.65), MCMCThreads(), 1000, 10, progres=false);
┌ Info: Found initial step size
│ ϵ = 0.05
└ @ Turing.Inference /home/chaggar/.julia/packages/Turing/XLLTf/src/inference/hmc.jl:188
Sampling: 7%|███ | ETA: 9:32:28
At N = 50, integrated over the time span (0.0, 1.0) and saved at 0.05 steps, inference using the NUTS sampler was slow and did not appear to be converging. The ETA continued to increase toward 9.5 hours before I terminated the run. There are at least two likely reasons for this: (1) the model has reached a critical level of complexity such that the NUTS sampler cannot efficiently explore the posterior parameter space; (2) there is not enough data for the parameters to be identified, resulting in non-convergence.
To test whether (1) or (2) is the case, I will run a simulation integrating over a shorter time window and saving at 0.0125 steps.
problem = ODEProblem(NetworkDiffusion, rand(N), (0.0,0.25), p);
data = Array(solve(problem, AutoTsit5(Rosenbrock23()), saveat=0.0125))
plot(data')
@model function fitode_short(data, problem)
    u_n = size(data)[1]
    σ ~ InverseGamma(2, 3)  # ~ is the tilde character
    ρ ~ truncated(Normal(5, 10.0), 0.0, 10)
    u ~ filldist(truncated(Normal(0.5, 2.0), 0.0, 1.0), u_n)
    prob = remake(problem, u0=u, p=ρ)
    predicted = solve(prob, Tsit5(), saveat=0.0125)
    @threads for i = 1:length(predicted)
        data[:, i] ~ MvNormal(predicted[i], σ)
    end
end
fitode_short (generic function with 1 method)
model_short = fitode_short(data, problem)
chain = sample(model_short, NUTS(0.65), 1000)
┌ Info: Found initial step size
│ ϵ = 0.05
└ @ Turing.Inference /home/chaggar/.julia/packages/Turing/XLLTf/src/inference/hmc.jl:188
Sampling: 100%|█████████████████████████████████████████| Time: 0:10:03
Chains MCMC chain (1000×64×1 Array{Float64,3}):
Iterations = 1:1000
Thinning interval = 1
Chains = 1
Samples per chain = 1000
parameters = u[1], u[2], u[3], u[4], u[5], u[6], u[7], u[8], u[9], u[10], u[11], u[12], u[13], u[14], u[15], u[16], u[17], u[18], u[19], u[20], u[21], u[22], u[23], u[24], u[25], u[26], u[27], u[28], u[29], u[30], u[31], u[32], u[33], u[34], u[35], u[36], u[37], u[38], u[39], u[40], u[41], u[42], u[43], u[44], u[45], u[46], u[47], u[48], u[49], u[50], ρ, σ
internals = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, max_hamiltonian_energy_error, n_steps, nom_step_size, numerical_error, step_size, tree_depth
Summary Statistics
parameters mean std naive_se mcse ess rhat
Symbol Float64 Float64 Float64 Float64 Float64 Float64
u[1] 0.2889 0.0025 0.0001 0.0001 1683.8969 1.0014
u[2] 0.6668 0.0026 0.0001 0.0001 1998.6946 0.9993
u[3] 0.3163 0.0022 0.0001 0.0000 1997.6755 1.0002
u[4] 0.4265 0.0025 0.0001 0.0001 2066.0053 0.9994
u[5] 0.8937 0.0022 0.0001 0.0000 2555.3926 0.9992
u[6] 0.4865 0.0025 0.0001 0.0000 2050.4011 0.9999
u[7] 0.8276 0.0023 0.0001 0.0000 2078.1141 0.9992
u[8] 0.5499 0.0022 0.0001 0.0000 2281.4740 1.0003
u[9] 0.8623 0.0027 0.0001 0.0001 2466.2154 0.9997
u[10] 0.4806 0.0025 0.0001 0.0001 2967.2517 0.9990
u[11] 0.0422 0.0024 0.0001 0.0001 2391.9083 0.9995
u[12] 0.7958 0.0025 0.0001 0.0001 1566.6635 1.0001
u[13] 0.9934 0.0024 0.0001 0.0001 1093.9042 1.0055
u[14] 0.5682 0.0022 0.0001 0.0000 1880.2845 1.0017
u[15] 0.5005 0.0023 0.0001 0.0000 2626.9501 1.0005
u[16] 0.6487 0.0025 0.0001 0.0000 2387.2433 0.9990
u[17] 0.9549 0.0026 0.0001 0.0001 1867.2469 0.9991
u[18] 0.4633 0.0024 0.0001 0.0001 1919.5743 0.9991
u[19] 0.9588 0.0025 0.0001 0.0000 2616.4177 0.9992
u[20] 0.0191 0.0025 0.0001 0.0001 1426.3207 1.0019
u[21] 0.2291 0.0024 0.0001 0.0000 2830.0746 0.9992
u[22] 0.3822 0.0025 0.0001 0.0000 2317.6966 0.9999
u[23] 0.0862 0.0026 0.0001 0.0000 2655.7836 0.9990
⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
29 rows omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
u[1] 0.2840 0.2871 0.2889 0.2906 0.2936
u[2] 0.6618 0.6651 0.6668 0.6685 0.6721
u[3] 0.3121 0.3148 0.3163 0.3177 0.3207
u[4] 0.4215 0.4249 0.4265 0.4282 0.4315
u[5] 0.8892 0.8923 0.8937 0.8951 0.8980
u[6] 0.4816 0.4848 0.4866 0.4882 0.4912
u[7] 0.8233 0.8260 0.8276 0.8292 0.8322
u[8] 0.5457 0.5484 0.5499 0.5513 0.5543
u[9] 0.8574 0.8604 0.8623 0.8642 0.8677
u[10] 0.4757 0.4789 0.4807 0.4823 0.4853
u[11] 0.0374 0.0406 0.0422 0.0438 0.0467
u[12] 0.7909 0.7941 0.7958 0.7975 0.8006
u[13] 0.9887 0.9918 0.9934 0.9950 0.9981
u[14] 0.5638 0.5667 0.5683 0.5697 0.5722
u[15] 0.4961 0.4988 0.5005 0.5021 0.5048
u[16] 0.6441 0.6470 0.6487 0.6505 0.6537
u[17] 0.9502 0.9531 0.9549 0.9567 0.9599
u[18] 0.4586 0.4617 0.4632 0.4650 0.4682
u[19] 0.9540 0.9571 0.9588 0.9604 0.9638
u[20] 0.0141 0.0174 0.0190 0.0209 0.0239
u[21] 0.2246 0.2275 0.2292 0.2308 0.2338
u[22] 0.3772 0.3806 0.3822 0.3838 0.3872
u[23] 0.0812 0.0844 0.0862 0.0879 0.0915
⋮ ⋮ ⋮ ⋮ ⋮ ⋮
29 rows omitted
scatter(vcat(data[:,1],p,0))
scatter!(mean(chain).nt.mean, yerr=std(Array(chain), dims=1))
The model converges for all parameters and produces highly accurate results.